#!/usr/bin/env python
from __future__ import print_function
import sys
import os
import re
import argparse
import pandas as pd
import numpy as np
from collections import defaultdict, namedtuple
my_bin_dir = os.path.dirname(os.path.realpath(sys.argv[0]))
sys.path.insert(0, os.path.join(my_bin_dir, "../tools"))
from bx.intervals.cluster import ClusterTree
from transcripts import gene_pred_iterator
from intervals import ChromosomeInterval
from rangeFinder import RangeFinder

# Notes:
#  - We used to write a misc_feature for each gap; however, this generated:
#     ERROR: valid [SEQ_FEAT.MissingGeneXref] Feature overlapped by 2 identical-length genes but has no cross-reference
#    If a locus tag was added to link the misc_feature to a gene, we got many more
#    SEQ_FEAT.GeneXrefStrandProblem errors, since the tbl format has no way to annotate
#    a one-base negative-strand feature.  So we just turned these into comments.


def parse_args():
    """Parse the command line and return the argparse namespace.

    As a side effect the module globals ``verbose`` and ``min_cds_size`` are
    set from the parsed options so the rest of the script can read them.
    """
    global verbose, min_cds_size
    desc = """convert a CAT generated genePred and associated information to an NCBI submission table file
    """
    parser = argparse.ArgumentParser(description=desc)
    parser.add_argument('--verbose', action="store_true", default=False,
                        help="run on some tracing")
    # positional arguments, in command line order
    positionals = (
        ("cat_genepred", "input CAT genePred results"),
        ("cat_genepred_info", "associated genePred info"),
        ("locus_tag_prefix", "locus tag prefix to use"),
        ("ncbi_tbl_file", "output NCBI table file"),
    )
    for arg_name, arg_help in positionals:
        parser.add_argument(arg_name, help=arg_help)
    parser.add_argument('--min-cds-size', default=90, type=int, help='discard CDS with less than this many bases.')
    opts = parser.parse_args()
    verbose = opts.verbose
    min_cds_size = opts.min_cds_size
    return opts


# This maps Ensembl/GENCODE biotypes to feature-level identifiers.
# It is consulted by get_transcript_feature() for transcripts without a CDS,
# and check_defs() requires every biotype mapped to 'ncRNA' here to also have
# an entry in ncrna_class.
#
# Original rationale: non-transcribed pseudogenes are called mRNAs as they are
# mostly annotated from protein alignments.  Transcribed pseudogenes are called
# ncRNAs, because they actually produce a transcript that could
# have non-coding function.
# NOTE(review): the table below maps every non-immuno *_pseudogene biotype
# except polymorphic_pseudogene to 'ncRNA', which does not match the
# "called mRNAs" statement above -- confirm which is intended.
biotype_map = {
    '3prime_overlapping_ncRNA': 'ncRNA',
    '3prime_overlapping_ncrna': 'ncRNA',
    'IG_C_gene': 'C_segment',
    'IG_C_pseudogene': 'C_segment',
    'IG_D_gene': 'D_region',
    'IG_D_pseudogene': 'D_region',
    'IG_J_gene': 'J_segment',
    'IG_J_pseudogene': 'J_segment',
    'IG_V_gene': 'V_region',
    'IG_V_pseudogene': 'V_region',
    'IG_pseudogene': 'J_segment',
    'Mt_rRNA': 'rRNA',
    'Mt_tRNA': 'tRNA',
    'TEC': 'mRNA',
    'TR_C_gene': 'C_region',
    'TR_C_pseudogene': 'C_region',
    'TR_J_gene': 'J_segment',
    'TR_J_pseudogene': 'J_segment',
    'TR_V_gene': 'V_region',
    'TR_V_pseudogene': 'V_region',
    'antisense': 'ncRNA',
    'antisense_RNA': 'ncRNA',
    'bidirectional_promoter_lncRNA': 'ncRNA',
    'bidirectional_promoter_lncrna': 'ncRNA',
    'lincRNA': 'ncRNA',
    'macro_lncRNA': 'ncRNA',
    'miRNA': 'ncRNA',
    'misc_RNA': 'ncRNA',
    'non_coding': 'ncRNA',
    'non_stop_decay': 'mRNA',
    'nonsense_mediated_decay': 'mRNA',
    'polymorphic_pseudogene': 'mRNA',  # coding in some humans
    'processed_pseudogene': 'ncRNA',
    'processed_transcript': 'ncRNA',
    'protein_coding': 'mRNA',
    'pseudogene': 'ncRNA',
    'rRNA': 'rRNA',
    'retained_intron': 'ncRNA',
    'ribozyme': 'ncRNA',
    'sRNA': 'ncRNA',
    'scaRNA': 'ncRNA',
    'scRNA': 'ncRNA',
    'sense_intronic': 'ncRNA',
    'sense_overlapping': 'ncRNA',
    'snRNA': 'ncRNA',
    'snoRNA': 'ncRNA',
    'transcribed_processed_pseudogene': 'ncRNA',
    'transcribed_unitary_pseudogene': 'ncRNA',
    'transcribed_unprocessed_pseudogene': 'ncRNA',
    'translated_unprocessed_pseudogene': 'ncRNA',
    'translated_processed_pseudogene': 'ncRNA',
    'unitary_pseudogene': 'ncRNA',   # only in human
    'unknown_likely_coding': 'mRNA',
    'unprocessed_pseudogene': 'ncRNA',
    'vaultRNA': 'ncRNA'}


# For non-coding RNAs, this defines the value written for the ncRNA_class
# qualifier, keyed by transcript biotype.  Membership in this table is also
# what is_non_coding() tests.
ncrna_class = {
    '3prime_overlapping_ncRNA': 'other',
    '3prime_overlapping_ncrna': 'other',
    'antisense_RNA': 'antisense',
    'bidirectional_promoter_lncrna': 'lncRNA',
    'bidirectional_promoter_lncRNA': 'lncRNA',
    'lincRNA': 'lncRNA',
    'macro_lncRNA': 'lncRNA',
    'miRNA': 'miRNA',
    'processed_pseudogene': 'other',
    'processed_transcript': 'other',
    'pseudogene': 'other',
    'retained_intron': 'other',
    'ribozyme': 'ribozyme',
    'sRNA': 'other',
    'scRNA': 'other',
    'scaRNA': 'other',
    'sense_intronic': 'other',
    'sense_overlapping': 'other',
    'snRNA': 'snRNA',
    'snoRNA': 'snoRNA',
    'transcribed_processed_pseudogene': 'other',
    'transcribed_unitary_pseudogene': 'other',
    'transcribed_unprocessed_pseudogene': 'other',
    'translated_processed_pseudogene': 'other',
    'translated_unprocessed_pseudogene': 'other',
    'unitary_pseudogene': 'other',
    'unprocessed_pseudogene': 'other',
    'antisense': 'other',
    'misc_RNA': 'other',
    'non_coding': 'other',
    'vaultRNA': 'vault_RNA'}

# For pseudogenes, this defines the value written for the gene-level
# pseudogene qualifier, keyed by gene biotype.  Genes whose biotype is in this
# table never get a CDS computed (see convert_transcript).
pseudo_map = {
    'IG_C_pseudogene': 'unknown',
    'IG_D_pseudogene': 'unknown',
    'IG_J_pseudogene': 'unknown',
    'IG_V_pseudogene': 'unknown',
    'IG_pseudogene': 'unknown',
    'TR_C_pseudogene': 'unknown',
    'TR_J_pseudogene': 'unknown',
    'TR_V_pseudogene': 'unknown',
    'processed_pseudogene': 'processed',
    'pseudogene': 'unknown',
    'transcribed_processed_pseudogene': 'processed',
    'transcribed_unprocessed_pseudogene': 'unprocessed',
    'translated_unprocessed_pseudogene': 'unprocessed',
    'unprocessed_pseudogene': 'unprocessed',
    'translated_processed_pseudogene': 'processed'}

# Feature keys (values of biotype_map) that are immuno components.
# Transcript features with one of these keys are written without a
# transcript_id qualifier (see write_transcript_features).
immuno_features = frozenset(['C_region',
                             'C_segment',
                             'D_region',
                             'J_segment',
                             'V_region'])


def check_defs():
    """Verify the biotype definitions are semi-sane.

    Every biotype that biotype_map turns into an 'ncRNA' feature must have an
    ncrna_class entry, and no other biotype may have one.

    Raises:
        AssertionError: with the offending biotype as its message.
    """
    for biotype, feature in biotype_map.items():
        if feature == 'ncRNA':
            assert biotype in ncrna_class, biotype
        else:
            assert biotype not in ncrna_class, biotype
        # No biotype maps to a 'pseudo' feature, so there is nothing to
        # cross-check against pseudo_map here.


class GeneData(object):
    """Annotations for a single gene.

    Holds the gene id and its transcripts (TranscriptData, keyed by
    transcript id).  The sorted transcript list and the gene region are
    computed on first use and cached; add() invalidates both caches.
    """
    def __init__(self, gene_id):
        self.gene_id = gene_id
        self.txds_by_id = dict()  # TranscriptData by trans_id
        self.__region = None  # lazy
        self.__txds = None  # lazy sort

    @property
    def txds(self):
        "transcripts sorted by (start, name)"
        if self.__txds is None:
            self.__txds = sorted(self.txds_by_id.values(), key=lambda t: (t.tx.start, t.tx.name))
        return self.__txds

    @property
    def chrom(self):
        "chromosome of the first sorted transcript"
        # go through the txds property: the cache may not be built yet
        return self.txds[0].tx.chromosome

    @property
    def strand(self):
        "strand of the first sorted transcript"
        return self.txds[0].tx.strand

    @property
    def start(self):
        "smallest transcript start"
        return min(txd.tx.start for txd in self.txds)

    @property
    def stop(self):
        "largest transcript stop"
        return max(txd.tx.stop for txd in self.txds)

    @property
    def region(self):
        "ChromosomeInterval spanning all transcripts of the gene"
        if self.__region is None:
            self.__region = ChromosomeInterval(self.chrom, self.start, self.stop, self.strand)
        return self.__region

    def add(self, txd):
        """Add a transcript to the gene.

        Raises:
            Exception: if a transcript with the same id was already added.
        """
        if txd.transcript_id in self.txds_by_id:
            raise Exception("Duplicate transcript id: {} {}".format(self.gene_id, txd.transcript_id))
        self.txds_by_id[txd.transcript_id] = txd
        # both derived values depend on the transcript set
        self.__txds = None
        self.__region = None


class TranscriptData(namedtuple("TranscriptData", ("tx", "attrs"))):
    """A transcript annotation (tx) paired with its attributes (attrs)."""
    __slots__ = ()

    @property
    def transcript_id(self):
        "id of the transcript, taken from the annotation's name"
        annotation = self.tx
        return annotation.name

    @property
    def exons(self):
        "set of the distinct exon intervals of the transcript"
        return set(self.tx.exon_intervals)


def load_annotations(gp_file, gp_info_file):
    """Load all transcripts and group them by chromosome and gene.

    Args:
        gp_file: genePred file, read with gene_pred_iterator.
        gp_info_file: tab-separated file with a transcript_id column; the row
            for each transcript becomes its attributes.

    Returns:
        defaultdict mapping chromosome -> {gene id (name2) -> GeneData}.
    """
    by_chrom_gene_id = defaultdict(dict)
    attrs = pd.read_csv(gp_info_file, sep='\t').set_index(['transcript_id'])
    for tx in gene_pred_iterator(gp_file):
        genes = by_chrom_gene_id[tx.chromosome]
        gene = genes.get(tx.name2)
        if gene is None:
            gene = genes[tx.name2] = GeneData(tx.name2)
        # label lookup; .loc replaces the .ix indexer that pandas removed
        gene.add(TranscriptData(tx, attrs.loc[tx.name]))
    return by_chrom_gene_id


class NcbiTblWriter(object):
    """Queueing writer for NCBI feature table rows.

    Rows are collected in memory because transcripts have to be examined
    before their gene can be written.  push() sets the pending rows aside and
    starts an empty list, pop() brings the most recently pushed rows back,
    and flush() writes the pending rows to the file."""

    def __init__(self, fh):
        self.fh = fh
        self.recs = []
        self.saved = []

    def push(self):
        "set the pending rows aside and start with an empty list"
        pending, self.recs = self.recs, []
        self.saved.append(pending)

    def pop(self):
        "make the most recently pushed rows pending again; nothing may be pending"
        assert not self.recs
        self.recs = self.saved.pop()

    def flush(self):
        "write the pending rows to the file and clear them"
        for line in self.recs:
            print(line, file=self.fh)
        self.recs = []

    def start_seq(self, seqname):
        "queue the header row that starts the features of a sequence"
        self.recs.append(">Features {}".format(seqname))

    def __write_feature_region(self, region, feature_key, start_incmpl, end_incmpl):
        "queue one coordinate row; feature_key is None on continuation rows"
        # zero-based half-open -> one-based closed
        first, last = str(region.start + 1), str(region.stop)
        if region.strand == '-':
            first, last = last, first
        if start_incmpl:
            first = "<" + first
        if end_incmpl:
            last = ">" + last
        fields = [first, last]
        if feature_key is not None:
            fields.append(feature_key)
        self.recs.append("\t".join(fields))

    def __merge_adjacent_features(self, regions):
        "join each region to its predecessor when it starts where the predecessor stops"
        combined = [regions[0]]
        for region in regions[1:]:
            prev = combined[-1]
            if prev.stop != region.start:
                combined.append(region)
                continue
            combined[-1] = ChromosomeInterval(region.chromosome, prev.start, region.stop, region.strand)
        return combined

    def write_feature(self, regions, feature_key, start_incmpl=False, end_incmpl=False):
        """Queue the coordinate rows of a feature.

        regions is a non-empty list of interval objects with start, stop and
        strand (zero-based, half-open); coordinates are swapped on the
        negative strand.  For every key except "CDS", abutting regions are
        merged first.  start_incmpl/end_incmpl are relative to the positive
        strand and mark the outer ends of the feature as partial."""
        strand = regions[0].strand
        if feature_key != "CDS":
            regions = self.__merge_adjacent_features(regions)

        # rows go out in transcription order
        if strand == '+':
            ordered = sorted(regions, key=lambda r: r.start)
        else:
            ordered = sorted(regions, key=lambda r: -r.stop)
        if strand == '-':
            start_incmpl, end_incmpl = end_incmpl, start_incmpl

        final = ordered[-1]
        for idx, region in enumerate(ordered):
            if idx == 0:
                # only the first row carries the key and the start marker
                self.__write_feature_region(region, feature_key,
                                            start_incmpl=start_incmpl,
                                            end_incmpl=(end_incmpl and (len(ordered) == 1)))
            else:
                self.__write_feature_region(region, None,
                                            start_incmpl=None,
                                            end_incmpl=(end_incmpl and (region == final)))

    def write_locus_tag(self, locus_tag_prefix, gene):
        """queue a locus_tag qualifier in the form ${locus_tag_prefix}_${gene_name}"""
        tag = '{}_{}'.format(locus_tag_prefix, get_gene_locus_tag(gene))
        self.write_qualifier("locus_tag", tag)

    def write_qualifier(self, qualifier_key, qualifier_value=None):
        """queue a qualifier row; the value column is omitted when the value is None"""
        fields = ["", "", "", qualifier_key]
        if qualifier_value is not None:
            fields.append(qualifier_value)
        self.recs.append("\t".join(fields))

    def write_note(self, note):
        "queue a note qualifier"
        self.write_qualifier("note", note)


class CdsSpec(object):
    """Everything collected while working out the CDS of one transcript."""
    def __init__(self, strand):
        self.strand = strand
        self.regions = []     # CDS regions, frame adjusted
        self.gaps = []        # pieces trimmed off regions to restore frame
        # start/end completeness (DNA (+) strand), unknown until computed
        self.start_incmpl = None
        self.end_incmpl = None

    def __str__(self):
        fields = (self.strand, self.regions, self.gaps, self.start_incmpl, self.end_incmpl)
        return "CDS: {} regions={}, gaps={}, incmpl={}/{}".format(*fields)

    @property
    def cds_len(self):
        "total length of all CDS regions"
        return sum(len(region) for region in self.regions)

    def is_mult_three(self):
        "is CDS length a multiple of three?"
        return self.cds_len % 3 == 0


def frame_incr(frame, amt=1):
    """Return frame advanced by amt (positive or negative), as a value in 0..2.

    A negative frame is handled by a separate formula (presumably a negative
    frame means "no frame" -- confirm against the genePred reader)."""
    if frame < 0:
        wrapped = (-amt) % 3
        return (frame - (amt - wrapped)) % 3
    return (frame + amt) % 3


def adjust_cds_start(cds_interval, expected_frame, frame):
    """Trim the transcription-order start of cds_interval until its frame
    matches expected_frame.

    Returns (gap, trimmed) where gap is the interval of the bases removed and
    trimmed is the remaining CDS interval, which can be zero length."""
    # count single-base frame steps needed to reach the expected frame
    shift = 0
    while frame != expected_frame:
        frame = frame_incr(frame)
        shift += 1

    chrom, strand = cds_interval.chromosome, cds_interval.strand
    lo, hi = cds_interval.start, cds_interval.stop
    if strand == '+':
        gap = ChromosomeInterval(chrom, lo, lo + shift, strand)
        lo = min(lo + shift, hi)   # clamp so the block never goes negative
    else:
        gap = ChromosomeInterval(chrom, hi - shift, hi, strand)
        hi = max(hi - shift, lo)   # clamp so the block never goes negative
    cds_interval = ChromosomeInterval(chrom, lo, hi, strand)
    if verbose:
        size = len(cds_interval)
        print("  adj", cds_interval, "[{}]".format(size), "frame", "{}..{}".format(frame, frame_incr(frame, size)), shift, "gap", gap, file=sys.stderr)
    return gap, cds_interval


def get_gene_locus_tag(gene):
    "gene id with the species identifier of CAT names (Clint_Chim, etc) removed"
    # keep only what follows the last underscore
    return gene.gene_id.rsplit('_', 1)[-1]


def get_transcript_base_id(txd):
    "transcript name with the species identifier of CAT names (Clint_Chim, etc) removed"
    # keep only what follows the last underscore
    full_name = txd.tx.name
    return full_name.rsplit('_', 1)[-1]


def get_transcript_gnl_id(locus_tag_prefix, txd):
    "transcript id in the gnl|${locus_tag_prefix}|${tx_name} format"
    base_id = get_transcript_base_id(txd)
    return 'gnl|{0}|{1}'.format(locus_tag_prefix, base_id)


def get_transcript_feature(txd):
    """Feature key for a transcript: 'mRNA' when it has a CDS, otherwise
    whatever biotype_map gives for its transcript biotype."""
    if txd.tx.cds_size > 0:
        return 'mRNA'
    # non-coding: translate the biotype through the table
    biotype = txd.attrs.transcript_biotype
    return biotype_map[biotype]


def is_non_coding(txd):
    "does the transcript biotype have an ncRNA class?"
    biotype = txd.attrs.transcript_biotype
    return biotype in ncrna_class


def trans_has_cds(txd):
    "does the transcript annotation have a non-empty CDS?"
    cds_size = txd.tx.cds_size
    return cds_size > 0


def have_gene_name(gene):
    "is there anything for the gene name, real or generated?"
    # the name is read from the gene's first transcript
    name = gene.txds[0].attrs.source_gene_common_name
    return not pd.isnull(name)


def get_real_gene_name(gene):
    """Return the source gene common name of the gene's first transcript.

    Returns None when there is no name at all, or when the name looks clone
    derived (letters, digits, a dot, digits).  GENCODE uses clone derived
    names for some genomic annotations, which NCBI doesn't (understandably)
    like, so those are dropped.
    """
    name = gene.txds[0].attrs.source_gene_common_name
    if pd.isnull(name):
        # a missing name is not a string, so it cannot go through re.match
        return None
    if re.match(r'^[A-Za-z]+[0-9]+\.[0-9]+$', name):
        return None
    return name


def get_coding_product(gene, txd, trans_num):
    """Product text for a coding transcript: the real gene name, or the
    transcript base id when there is none, followed by the isoform number."""
    label = get_real_gene_name(gene)
    if label is None:
        label = get_transcript_base_id(txd)
    return "{} isoform {}".format(label, trans_num)


def get_ncrna_product(txd):
    "product text for a non-coding transcript: its biotype with underscores turned into spaces"
    words = txd.attrs.transcript_biotype.split('_')
    return ' '.join(words)


def get_rrna_product(name):
    """Translate an rRNA gene name to its product name; names without a
    translation are returned unchanged."""
    translations = {
        '5S_rRNA': '5S rRNA',
        '5_8S_rRNA': '5.8S rRNA',
        'RNA5-8SN5': '5.8S rRNA',
    }
    return translations.get(name, name)


def get_product(gene, txd, trans_num):
    """Product text for a transcript, chosen by kind: ncRNA class biotypes
    first, then tRNA and rRNA features, and coding naming for the rest."""
    if is_non_coding(txd):
        return get_ncrna_product(txd)
    feature = get_transcript_feature(txd)
    if feature == "tRNA":
        return "tRNA-Xxx"
    if feature == 'rRNA':
        return get_rrna_product(get_real_gene_name(gene))
    return get_coding_product(gene, txd, trans_num)


def get_transcript_protein_id(locus_tag, txd, trans_num):
    "protein id in the gnl|${locus_tag}|${tx_name}_${trans_num}_prot format"
    base_id = get_transcript_base_id(txd)
    return "gnl|{0}|{1}_{2}_prot".format(locus_tag, base_id, trans_num)


def write_protein_id(locus_tag_prefix, txd, trans_num, tblwr):
    "queue the protein_id qualifier of a transcript"
    protein_id = get_transcript_protein_id(locus_tag_prefix, txd, trans_num)
    tblwr.write_qualifier('protein_id', protein_id)


def write_product(locus_tag_prefix, gene, txd, trans_num, tblwr):
    "write product qualifier, generated-name note and protein id as needed"
    attrs = txd.attrs

    # Only transcripts with a source gene get a product; without a source
    # gene it is a CGP annotation.  The product tag must match the gene-level
    # gene tag or the conversion program gets mad.
    has_source_gene = not pd.isnull(attrs.source_gene)
    if has_source_gene:
        tblwr.write_qualifier("product", get_product(gene, txd, trans_num))

    # there is a name, but it was rejected as a generated one
    if have_gene_name(gene) and (get_real_gene_name(gene) is None):
        tblwr.write_note("No associated gene symbol, Ensembl generated symbol was {}".format(attrs.source_gene_common_name))

    if trans_has_cds(txd):
        write_protein_id(locus_tag_prefix, txd, trans_num, tblwr)


def make_exon_idx_iter(txd):
    """Return an iterable of exon indexes in order of transcription.

    Indexes ascend on the '+' strand and descend on any other strand.
    """
    num_exons = len(txd.tx.exon_intervals)
    # range rather than xrange so this runs on both Python 2 and Python 3
    if txd.tx.strand == '+':
        return range(0, num_exons)
    else:
        return range(num_exons - 1, -1, -1)


def add_cds_region(cds_interval, frame, expected_frame, cds_spec):
    """Record the next CDS region (in transcription order) in cds_spec.

    When the region's frame differs from the expected frame the region is
    trimmed and the trimmed part recorded as a gap, as NCBI doesn't
    explicitly support frame shifts.  Zero-length regions are not recorded.
    Returns the frame expected for the following region."""
    # FIXME: much of this code is now in transcripts.py, switch to using it.
    # FIXME: /start_codon could be used rather adjust CDS at start
    if verbose:
        print("  cds", cds_interval, "[{}]".format(len(cds_interval)), "frame", frame, "exframe", expected_frame, file=sys.stderr)

    if frame != expected_frame:
        # out of frame: drop bases from the start of the region
        gap, cds_interval = adjust_cds_start(cds_interval, expected_frame, frame)
        cds_spec.gaps.append(gap)

    region_len = len(cds_interval)
    if region_len != 0:
        cds_spec.regions.append(cds_interval)

    return frame_incr(expected_frame, region_len)


def compute_cds_completeness(txd, trans_start_incmpl, cds_spec):
    """Set cds_spec.start_incmpl and cds_spec.end_incmpl.

    trans_start_incmpl is in the direction of transcription, while the
    flags stored on cds_spec are relative to the DNA (+) strand, so the two
    are exchanged for transcripts not on the '+' strand."""
    tx, attrs = txd.tx, txd.attrs
    on_plus = (tx.strand == '+')

    # incompleteness checks based on frame, source gene and start/stop codons in genome
    five_prime_open = bool(trans_start_incmpl)
    three_prime_open = bool((not cds_spec.is_mult_three()) or (not attrs.valid_stop))
    if on_plus:
        start_incmpl, end_incmpl = five_prime_open, three_prime_open
    else:
        start_incmpl, end_incmpl = three_prime_open, five_prime_open

    # factor in completeness based on source
    if tx.cds_start_stat == "incmpl":
        start_incmpl = True
    if tx.cds_end_stat == "incmpl":
        end_incmpl = True

    # Without a valid start codon, force start incompleteness.  This will
    # make non-ATG starts incomplete, however unless we add GENCODE non-ATG start
    # tags, this probably produces the fewest errors.
    if not attrs.valid_start:
        if on_plus:
            start_incmpl = True
        else:
            end_incmpl = True

    # With a valid stop, force end completeness or tbl2asn generates
    # SEQ_FEAT.PartialProblem; this overrides the source gene being incomplete.
    # If CDS has been shortened at the 3' end, then that is done because we
    # found a valid stop.
    if on_plus:
        if attrs.valid_stop or (attrs.adj_stop < tx.thick_stop):
            end_incmpl = False
    elif attrs.valid_stop or (attrs.adj_start > tx.thick_start):
        start_incmpl = False

    cds_spec.start_incmpl = start_incmpl
    cds_spec.end_incmpl = end_incmpl


def get_cds_regions(txd):
    """Build the CdsSpec of a transcript.

    The coding range is the adj_start/adj_stop attributes, or the annotation's
    thick_start/thick_stop when either attribute is NaN.  Exons are visited in
    transcription order; where an exon's frame does not match the expected
    frame the CDS is trimmed and a gap recorded.  Regions and gaps are
    returned in genomic order."""
    tx = txd.tx
    cds_spec = CdsSpec(tx.strand)

    adj_start, adj_stop = txd.attrs["adj_start"], txd.attrs["adj_stop"]
    if np.isnan(adj_start) or np.isnan(adj_stop):
        adj_start, adj_stop = tx.thick_start, tx.thick_stop
    else:
        adj_start, adj_stop = int(adj_start), int(adj_stop)
    coding_interval = ChromosomeInterval(tx.coding_interval.chromosome, adj_start, adj_stop, tx.strand)

    next_frame = 0
    trans_start_incmpl = None   # decided by the first exon that has CDS
    for iexon in make_exon_idx_iter(txd):
        cds_interval = tx.exon_intervals[iexon].intersection(coding_interval)
        if cds_interval is None:
            continue
        exon_frame = tx.exon_frames[iexon]
        if trans_start_incmpl is None:
            trans_start_incmpl = (exon_frame != 0)
        next_frame = add_cds_region(cds_interval, exon_frame, next_frame, cds_spec)

    compute_cds_completeness(txd, trans_start_incmpl, cds_spec)

    # put into genomic order (required by exon splitting)
    by_start = lambda interval: interval.start
    cds_spec.regions = sorted(cds_spec.regions, key=by_start)
    cds_spec.gaps = sorted(cds_spec.gaps, key=by_start)
    return cds_spec


def is_possible_unitary_pseudo(txd, cds_spec):
    """Call as pseudo if the CDS was reduced in length by > 30bp or 20% or
    more of the CDS was lost.

    Args:
        txd: transcript whose tx.cds_size is the source CDS length.
        cds_spec: object whose cds_len is the mapped CDS length.

    Returns:
        bool: True when the transcript looks like a unitary pseudogene.

    Raises:
        AssertionError: if the mapped CDS is longer than the source CDS.
    """
    src_cds_len = txd.tx.cds_size
    mapped_cds_len = cds_spec.cds_len
    assert mapped_cds_len <= src_cds_len

    # bases lost is source minus mapped; the reverse difference is never
    # positive, which would leave the 30bp rule permanently off
    lost = src_cds_len - mapped_cds_len
    return bool(((mapped_cds_len > 0) and (lost > 30))
                or (mapped_cds_len <= (0.80 * src_cds_len)))


def write_cds_regions(locus_tag_prefix, gene, txd, trans_num, cds_spec, tblwr):
    "queue the CDS feature of a transcript along with its qualifiers and gap notes"
    tblwr.write_feature(cds_spec.regions, "CDS", cds_spec.start_incmpl, cds_spec.end_incmpl)

    # FIXME: codon_start could be set rather than advancing to first codon
    tblwr.write_qualifier('codon_start', "1")
    tblwr.write_locus_tag(locus_tag_prefix, gene)

    # product and protein id must be same on mRNA and CDS
    write_product(locus_tag_prefix, gene, txd, trans_num, tblwr)

    transcript_id = get_transcript_gnl_id(locus_tag_prefix, txd)
    tblwr.write_qualifier('transcript_id', transcript_id)

    # one note per gap, in one-based closed coordinates
    gap_note = "a gap was added at {}-{} so that CDS maintains frame; possibly due to genomic or alignment errors"
    for gap in cds_spec.gaps:
        tblwr.write_note(gap_note.format(gap.start+1, gap.stop))


def write_cds(locus_tag_prefix, gene, txd, trans_num, cds_spec, tblwr):
    "queue the CDS feature unless no CDS regions are left; a dropped CDS is traced when verbose"
    if cds_spec.regions:
        write_cds_regions(locus_tag_prefix, gene, txd, trans_num, cds_spec, tblwr)
        return
    if verbose:
        print("  cds-dropped", txd.tx.name, file=sys.stderr)


def get_transcript_exons(txd):
    "exon intervals of the transcript annotation"
    annotation = txd.tx
    return annotation.exon_intervals


def split_exon_at_gap(remains, gap, cds_spec):
    """Possibly split what remains of an exon at a gap.

    Returns (split, remains): the part before the gap and the part after it.
    split is None and remains is returned untouched when the gap does not
    overlap the exon, or when the gap sits at the very start or end of the CDS."""
    overlap = remains.intersection(gap)
    if overlap is None:
        return None, remains  # not this exon's gap

    # a gap at either outer edge of the CDS just lets exon parts become UTR
    at_cds_start = (gap.start == cds_spec.regions[0].start)
    if at_cds_start or (gap.stop == cds_spec.regions[-1].stop):
        return None, remains  # gap becomes UTR

    before = ChromosomeInterval(gap.chromosome, remains.start, overlap.start, gap.strand)
    after = ChromosomeInterval(gap.chromosome, overlap.stop, remains.stop, gap.strand)
    return before, after


def add_gaps_to_exon(exon, cds_spec):
    """Cut one exon at the CDS gaps and return its non-empty pieces as a list."""
    pieces = []
    rest = exon
    # CDS and gaps in genomic order
    for gap in cds_spec.gaps:
        piece, rest = split_exon_at_gap(rest, gap, cds_spec)
        if piece is None:
            continue
        if len(piece) > 0:
            pieces.append(piece)
    if len(rest) > 0:
        pieces.append(rest)
    return pieces


def add_gaps_to_exons(exons, cds_spec):
    """Return a new exon list in which every exon is gapped where CDS gaps were added."""
    return [piece
            for exon in exons
            for piece in add_gaps_to_exon(exon, cds_spec)]


def write_transcript_features(locus_tag_prefix, gene, txd, trans_num, exons, unitary_pseudo, tblwr):
    """Queue the transcript-level feature for txd and all of its qualifiers.

    exons are the regions written for the feature (the caller may have
    gapped them).  unitary_pseudo marks a transcript called as a possible
    unitary pseudogene; it is then written as an ncRNA with pseudogene
    qualifiers instead of with a product.  The rows are queued on tblwr in
    the order of the statements below."""
    txa = txd.attrs
    feat = get_transcript_feature(txd)

    # rRNA biotypes whose gene name is not one of the two recognized rRNA names
    # are treated as rRNA pseudogenes
    is_rrna_pseudo = txa.transcript_biotype in ['Mt_rRNA', 'rRNA'] and txa.source_gene_common_name not in ['5S_rRNA', '5_8S_rRNA']
    if unitary_pseudo or is_rrna_pseudo:
        feat = 'ncRNA'  # force into ncRNA for unitary duplicate case or rRNA pseudogene case

    # feature with correct type and all exons
    tblwr.write_feature(exons, feat)

    # don't include transcript_id on immuno region features
    if feat not in immuno_features:
        tblwr.write_qualifier('transcript_id', get_transcript_gnl_id(locus_tag_prefix, txd))

    # if non-coding, fill in ncRNA_class tag
    if is_non_coding(txd):
        tblwr.write_qualifier('ncRNA_class', ncrna_class[txa.transcript_biotype])

    # features forced to ncRNA above get the generic class
    # NOTE(review): nothing here prevents both ncRNA_class writes from firing
    # for the same transcript -- confirm that cannot happen in practice
    if unitary_pseudo or is_rrna_pseudo:
        tblwr.write_qualifier('ncRNA_class', 'other')

    tblwr.write_locus_tag(locus_tag_prefix, gene)

    # pseudogene cases get no product (and so no protein_id either)
    if not unitary_pseudo and not is_rrna_pseudo:
        write_product(locus_tag_prefix, gene, txd, trans_num, tblwr)

    # check for gaps
    # (any intron of 50 bases or less flags the feature)
    if any(len(x) <= 50 for x in txd.tx.intron_intervals):
        tblwr.write_qualifier('exception', 'low-quality sequence region')

    if unitary_pseudo:
        tblwr.write_qualifier('pseudogene', "unitary")
        tblwr.write_note("Putative inactivated transcript (unitary pseudogene)")

    if is_rrna_pseudo:
        tblwr.write_qualifier('pseudogene', 'unknown')
        tblwr.write_note('rRNA pseudogene {}'.format(txa.source_gene_common_name))
    elif txa.transcript_biotype == 'rRNA':
        tblwr.write_note('rRNA {} in GENCODE annotation'.format(txa.source_gene_common_name))

    # record cat pipeline information
    tblwr.write_note("CAT transcript id: {}".format(get_transcript_base_id(txd)))
    tblwr.write_note("CAT alignment id: {}".format(txa.alignment_id))
    if not pd.isnull(txa.source_transcript):
        tblwr.write_note("CAT source transcript id: {}".format(txa.source_transcript))
    if txa.transcript_biotype != 'unknown_likely_coding':
        tblwr.write_note("CAT source GENCODE transcript biotype: {}".format(txa.transcript_biotype))
    else:
        tblwr.write_note('CAT novel prediction: {}'.format(txa.transcript_modes))


def include_transcript(cds_spec):
    """Should the transcript be written?  Transcripts without a CDS spec
    always are; the others only when the CDS reaches min_cds_size."""
    if cds_spec is None:
        return True
    # don't include too-short CDS transcripts
    return cds_spec.cds_len >= min_cds_size


def convert_transcript(locus_tag_prefix, gene, txd, trans_num, tblwr):
    """Queue the features of one transcript on tblwr.

    A CDS is only worked out for transcripts that have one and whose gene
    biotype is not in pseudo_map.  Transcripts whose CDS ended up too short
    are skipped entirely.

    Returns:
        (has_cds, unitary_pseudo): whether the transcript annotation has a
        CDS and whether it was called a possible unitary pseudogene; both
        are False for a skipped transcript.
    """
    if verbose:
        print("converting", txd.tx.name, file=sys.stderr)
    cds_spec = None
    unitary_pseudo = False
    if trans_has_cds(txd) and (txd.attrs.gene_biotype not in pseudo_map):
        cds_spec = get_cds_regions(txd)
        unitary_pseudo = is_possible_unitary_pseudo(txd, cds_spec)

    if not include_transcript(cds_spec):
        return False, False  # CDS truncated too much

    exons = get_transcript_exons(txd)
    # gapping needs at least one CDS region, since it looks at the first and
    # last region
    if (cds_spec is not None) and (len(cds_spec.regions) > 0):
        exons = add_gaps_to_exons(exons, cds_spec)

    write_transcript_features(locus_tag_prefix, gene, txd, trans_num, exons, unitary_pseudo, tblwr)
    # truthiness, not identity: a non-builtin boolean False must still count
    if (cds_spec is not None) and (not unitary_pseudo):
        write_cds(locus_tag_prefix, gene, txd, trans_num, cds_spec, tblwr)
    return trans_has_cds(txd), unitary_pseudo


def write_gene_features(locus_tag_prefix, gene, unitary_pseudo, tblwr):
    """Queue the gene feature with its locus tag and, where they apply,
    pseudogene qualifiers.  Gene attributes are read from the gene's first
    transcript."""
    txa = gene.txds[0].attrs  # arbitrary transcript

    # coordinates
    tblwr.write_feature([gene.region], "gene")
    tblwr.write_locus_tag(locus_tag_prefix, gene)

    biotype = txa.gene_biotype
    if biotype == 'rRNA' and txa.source_transcript_name.startswith('RNA5SP'):
        tblwr.write_qualifier('pseudogene', 'unknown')

    # a pseudogene biotype takes precedence over the unitary call
    if biotype in pseudo_map:
        tblwr.write_qualifier("pseudogene", pseudo_map[biotype])
        return
    if unitary_pseudo:
        tblwr.write_qualifier('pseudogene', "unitary")


def convert_gene_transcripts(locus_tag_prefix, gene, tblwr):
    """Convert every transcript of the gene, numbering them from 1.

    Returns (num_with_cds, num_unitary_pseudo): how many transcripts reported
    a CDS and how many were called unitary pseudogenes."""
    outcomes = [convert_transcript(locus_tag_prefix, gene, txd, trans_num, tblwr)
                for trans_num, txd in enumerate(gene.txds, 1)]
    num_with_cds = sum(1 for has_cds, _ in outcomes if has_cds)
    num_unitary_pseudo = sum(1 for _, is_pseudo in outcomes if is_pseudo)
    return num_with_cds, num_unitary_pseudo


def convert_gene(locus_tag_prefix, gene, tblwr):
    """Write a gene followed by its transcripts.

    The transcripts are converted first, because the gene feature depends on
    whether all of its coding transcripts were called unitary pseudogenes;
    their rows are set aside while the gene rows are written."""
    with_cds, unitary = convert_gene_transcripts(locus_tag_prefix, gene, tblwr)
    all_coding_unitary_pseudo = (with_cds > 0) and (unitary == with_cds)

    tblwr.push()  # save transcripts
    write_gene_features(locus_tag_prefix, gene, all_coding_unitary_pseudo, tblwr)
    tblwr.flush()  # write genes

    tblwr.pop()
    tblwr.flush()  # write transcripts


def convert_chrom(locus_tag_prefix, chrom, by_gene_id, tblwr):
    """Write the sequence header of a chromosome, then its genes ordered by
    start and strand."""
    tblwr.start_seq(chrom)
    tblwr.flush()

    def gene_order(gene):
        return (gene.start, gene.strand)

    for gene in sorted(by_gene_id.values(), key=gene_order):
        convert_gene(locus_tag_prefix, gene, tblwr)


def cat_to_ncbi_tbl(locus_tag_prefix, by_chrom_gene_id, tblwr):
    "write the features of every chromosome, in sorted chromosome name order"
    for chrom in sorted(by_chrom_gene_id):
        genes = by_chrom_gene_id[chrom]
        convert_chrom(locus_tag_prefix, chrom, genes, tblwr)


def main(args):
    """Check the biotype tables, load the annotations and write the NCBI
    table file named by args."""
    check_defs()
    annotations = load_annotations(args.cat_genepred, args.cat_genepred_info)
    with open(args.ncbi_tbl_file, "w") as tbl_fh:
        writer = NcbiTblWriter(tbl_fh)
        cat_to_ncbi_tbl(args.locus_tag_prefix, annotations, writer)


# only run the conversion when executed as a script, not when imported
if __name__ == "__main__":
    main(parse_args())
